{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OXYgXFeMgRep"
      },
      "source": [
        "Copyright 2019 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NcIzzCADklYm"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/google-research/google-research.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ngihcW7ckrDI"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "import tarfile\n",
        "import urllib\n",
        "import zipfile\n",
        "sys.path.append('./google-research')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y55h79H3XKSt"
      },
      "source": [
        "# Example of model training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lw1HFRd-UcTk"
      },
      "source": [
        "The steps below are taken from [model_train_eval](https://github.com/google-research/google-research/blob/master/kws_streaming/train/model_train_eval.py) - it has more tests for streaming, non-streaming, quantized and non-quantized models with TF and TFLite."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fathHzuEgx8_"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yP5WBy5O8Za8"
      },
      "outputs": [],
      "source": [
        "# TF streaming\n",
        "from kws_streaming.models import models\n",
        "from kws_streaming.models import utils\n",
        "from kws_streaming.layers.modes import Modes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wsUCmBzpk1jC"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import tensorflow.compat.v1 as tf1\n",
        "import logging\n",
        "from kws_streaming.models import model_flags\n",
        "from kws_streaming.models import model_params\n",
        "from kws_streaming.train import test\n",
        "from kws_streaming.train import train\n",
        "from kws_streaming import data\n",
        "tf1.disable_eager_execution()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RS8tH1UvUcTu"
      },
      "outputs": [],
      "source": [
        "config = tf1.ConfigProto()\n",
        "config.gpu_options.allow_growth = True\n",
        "sess = tf1.Session(config=config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zMdTK10tL2Dz"
      },
      "outputs": [],
      "source": [
        "# general imports\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import json\n",
        "import numpy as np\n",
        "import scipy as scipy\n",
        "import scipy.io.wavfile as wav\n",
        "import scipy.signal"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PMti09MMUcT2"
      },
      "outputs": [],
      "source": [
        "tf.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xHTcbg_ao586"
      },
      "outputs": [],
      "source": [
        "# Start from a clean default graph and register a fresh session with Keras.\n",
        "tf1.reset_default_graph()\n",
        "# Reuse `config` (built earlier with gpu_options.allow_growth = True) so this\n",
        "# session is configured the same way as the one it replaces.\n",
        "sess = tf1.Session(config=config)\n",
        "tf1.keras.backend.set_session(sess)\n",
        "tf1.keras.backend.set_learning_phase(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylPGCTPLh41F"
      },
      "source": [
        "## Set path to data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eEg-24R5UcT_"
      },
      "outputs": [],
      "source": [
        "# set PATH to data sets (for example to speech commands V2):\n",
        "# it can be downloaded from\n",
        "# https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz\n",
        "# if you have already run \"00_check-data.ipynb\" then the folder \"data2\" should be located in the current dir\n",
        "current_dir = os.getcwd()\n",
        "DATA_PATH = os.path.join(current_dir, \"data2/\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OObn08smUcUC"
      },
      "outputs": [],
      "source": [
        "def waveread_as_pcm16(filename):\n",
        "  \"\"\"Reads a wav file with scipy.io.wavfile.\n",
        "\n",
        "  Args:\n",
        "    filename: path to the wav file.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (wave_data, samplerate): the samples exactly as returned by\n",
        "    `wav.read` (no rescaling or conversion is done here) and the sample\n",
        "    rate stored in the file.\n",
        "  \"\"\"\n",
        "  samplerate, wave_data = wav.read(filename)\n",
        "  return wave_data, samplerate\n",
        "\n",
        "\n",
        "def wavread_as_float(filename, target_sample_rate=16000):\n",
        "  \"\"\"Reads a wav file, resamples it and scales it to float32.\n",
        "\n",
        "  Args:\n",
        "    filename: path to the wav file.\n",
        "    target_sample_rate: sample rate (Hz) of the returned signal.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (data, target_sample_rate): float32 samples divided by 32768 and\n",
        "    the sample rate they are expressed in.\n",
        "  \"\"\"\n",
        "  wave_data, samplerate = waveread_as_pcm16(filename)\n",
        "\n",
        "  # Only resample when the rates differ: an FFT round trip at the same rate\n",
        "  # is wasted work and can only add floating point noise.\n",
        "  if samplerate != target_sample_rate:\n",
        "    desired_length = int(\n",
        "        round(float(len(wave_data)) / samplerate * target_sample_rate))\n",
        "    wave_data = scipy.signal.resample(wave_data, desired_length)\n",
        "\n",
        "  # Normalize short ints to floats in range [-1..1).\n",
        "  # NOTE(review): the 32768 divisor assumes 16-bit PCM samples - confirm\n",
        "  # before feeding files stored in another sample format.\n",
        "  data = np.array(wave_data, np.float32) / 32768.0\n",
        "  return data, target_sample_rate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TYj0JGeHhtqc"
      },
      "outputs": [],
      "source": [
        "# Set path to wav file to visualize it\n",
        "wav_file = os.path.join(DATA_PATH, \"left/012187a4_nohash_0.wav\")\n",
        "\n",
        "# read audio file\n",
        "wav_data, samplerate = wavread_as_float(wav_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cLAd9tfiUcUK"
      },
      "outputs": [],
      "source": [
        "assert samplerate == 16000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r2yeKkLsiRWJ"
      },
      "outputs": [],
      "source": [
        "plt.plot(wav_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5_wbAZ3vhQh1"
      },
      "source": [
        "## Set path to a model with config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Ligfp0KUcUV"
      },
      "outputs": [],
      "source": [
        "# the selected model name should be one of:\n",
        "model_params.HOTWORD_MODEL_PARAMS.keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "04bbXWx2UcUa"
      },
      "outputs": [],
      "source": [
        "# This notebook is configured to work with 'ds_tc_resnet' and 'svdf'.\n",
        "MODEL_NAME = 'ds_tc_resnet'\n",
        "# MODEL_NAME = 'svdf'\n",
        "MODELS_PATH = os.path.join(current_dir, \"models\")\n",
        "MODEL_PATH = os.path.join(MODELS_PATH, MODEL_NAME + \"/\")\n",
        "MODEL_PATH"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l9sn53qfUcUd"
      },
      "outputs": [],
      "source": [
        "# delete previously trained model with its folder and create a new one,\n",
        "# so that re-running this cell does not fail with FileExistsError:\n",
        "import shutil\n",
        "if os.path.isdir(MODEL_PATH):\n",
        "  shutil.rmtree(MODEL_PATH)\n",
        "os.makedirs(MODEL_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OBHIwG8FUcUg"
      },
      "outputs": [],
      "source": [
        "# get toy model settings\n",
        "FLAGS = model_params.HOTWORD_MODEL_PARAMS[MODEL_NAME]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6z7ZVz3dUcUl"
      },
      "outputs": [],
      "source": [
        "# set path to data and model (where model will be stored)\n",
        "FLAGS.data_dir = DATA_PATH\n",
        "FLAGS.train_dir = MODEL_PATH\n",
        "\n",
        "# set speech feature extractor properties\n",
        "FLAGS.mel_upper_edge_hertz = 7600\n",
        "FLAGS.window_size_ms = 30.0\n",
        "FLAGS.window_stride_ms = 10.0\n",
        "FLAGS.mel_num_bins = 80\n",
        "FLAGS.dct_num_features = 40\n",
        "FLAGS.feature_type = 'mfcc_tf'\n",
        "FLAGS.preprocess = 'raw'\n",
        "\n",
        "# for numerical correctness of streaming and non streaming models set it to 1\n",
        "# but for real use case streaming set it to 0\n",
        "FLAGS.causal_data_frame_padding = 0\n",
        "\n",
        "FLAGS.use_tf_fft = True\n",
        "FLAGS.mel_non_zero_only = not FLAGS.use_tf_fft\n",
        "\n",
        "# set training settings\n",
        "FLAGS.train = 1\n",
        "# reduced number of training steps for test only\n",
        "# so model accuracy will be low,\n",
        "# to improve accuracy set how_many_training_steps = '40000,40000,20000,20000'\n",
        "FLAGS.how_many_training_steps = '400,400,400,400'\n",
        "FLAGS.learning_rate = '0.001,0.0005,0.0001,0.00002'\n",
        "FLAGS.lr_schedule = 'linear'\n",
        "FLAGS.verbosity = logging.INFO\n",
        "\n",
        "# data augmentation parameters\n",
        "FLAGS.resample = 0.15\n",
        "FLAGS.time_shift_ms = 100\n",
        "FLAGS.use_spec_augment = 1\n",
        "FLAGS.time_masks_number = 2\n",
        "FLAGS.time_mask_max_size = 25\n",
        "FLAGS.frequency_masks_number = 2\n",
        "FLAGS.frequency_mask_max_size = 7\n",
        "FLAGS.pick_deterministically = 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sL8zW5ADUcUp"
      },
      "outputs": [],
      "source": [
        "FLAGS.model_name = MODEL_NAME\n",
        "\n",
        "# model parameters are different for every model\n",
        "if MODEL_NAME == 'svdf':\n",
        "  FLAGS.model_name = MODEL_NAME\n",
        "  FLAGS.svdf_memory_size = \"4,10,10,10,10,10\"\n",
        "  FLAGS.svdf_units1 = \"16,32,32,32,64,128\"\n",
        "  FLAGS.svdf_act = \"'relu','relu','relu','relu','relu','relu'\"\n",
        "  FLAGS.svdf_units2 = \"40,40,64,64,64,-1\"\n",
        "  FLAGS.svdf_dropout = \"0.0,0.0,0.0,0.0,0.0,0.0\"\n",
        "  FLAGS.svdf_pad = 0\n",
        "  FLAGS.dropout1 = 0.0\n",
        "  FLAGS.units2 = ''\n",
        "  FLAGS.act2 = ''\n",
        "elif MODEL_NAME == 'ds_tc_resnet':\n",
        "  # it is an example of model streaming with strided convolution, strided pooling and dilated convolution\n",
        "  FLAGS.activation = 'relu'\n",
        "  FLAGS.dropout = 0.0\n",
        "  FLAGS.ds_filters = '128, 64, 64, 64, 128, 128'\n",
        "  FLAGS.ds_filter_separable = '1, 1, 1, 1, 1, 1'\n",
        "  FLAGS.ds_repeat = '1, 1, 1, 1, 1, 1'\n",
        "  FLAGS.ds_residual = '0, 1, 1, 1, 0, 0' # residual can not be applied with stride\n",
        "#   FLAGS.ds_kernel_size = '11, 5, 15, 7, 29, 1'\n",
        "  FLAGS.ds_kernel_size = '11, 5, 15, 17, 15, 1'\n",
        "  FLAGS.ds_dilation = '1, 1, 1, 1, 2, 1'\n",
        "  FLAGS.ds_stride = '1, 1, 1, 1, 1, 1'\n",
        "  FLAGS.ds_pool = '1, 2, 1, 1, 1, 1'\n",
        "  # model should be causal, so that we can convert it to streaming mode\n",
        "  # if model is non causal then all non causal components should use Delay layer\n",
        "  FLAGS.ds_padding = \"'causal', 'causal', 'causal', 'causal', 'causal', 'causal'\"\n",
        "else:\n",
        "  raise ValueError(\"set parameters for other models\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W7gI-9jTzb7M"
      },
      "outputs": [],
      "source": [
        "FLAGS.clip_duration_ms = 1000  # standard audio file in this data set has 1 sec length\n",
        "FLAGS.batch_size = 100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tt6a1m8RUcUs"
      },
      "outputs": [],
      "source": [
        "flags = model_flags.update_flags(FLAGS)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7vXrn3szUcUy"
      },
      "outputs": [],
      "source": [
        "with open(os.path.join(flags.train_dir, 'flags.json'), 'wt') as f:\n",
        "  json.dump(flags.__dict__, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8GWgpluSUcU1"
      },
      "outputs": [],
      "source": [
        "# visualize a model\n",
        "model_non_stream_batch = models.MODELS[flags.model_name](flags)\n",
        "tf.keras.utils.plot_model(\n",
        "    model_non_stream_batch,\n",
        "    show_shapes=True,\n",
        "    show_layer_names=True,\n",
        "    expand_nested=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PGbLrABNUcU5"
      },
      "outputs": [],
      "source": [
        "model_non_stream_batch.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sc3yteQJUcU8"
      },
      "source": [
        "## Model training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EXSDOpQ_UcU9"
      },
      "outputs": [],
      "source": [
        "# Model training\n",
        "train.train(flags)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RIr1DWLisMu9"
      },
      "source": [
        "## Run model evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "456ynjRxmdVc"
      },
      "source": [
        "### TF Run non streaming inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3e_X-fKdUcVC"
      },
      "outputs": [],
      "source": [
        "folder_name = 'tf'\n",
        "test.tf_non_stream_model_accuracy(flags, folder_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3dyT1fQcUcVF"
      },
      "source": [
        "More testing functions can be found at [test](https://github.com/google-research/google-research/blob/master/kws_streaming/train/test.py)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "01_train.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
